I. Project Overview
II. Exploratory Data Analysis
III. Data Preprocessing
IV. Methodology
V. Training Models
VI. Results
VII. Conclusions
VIII. References
A stroke is a medical condition, which occurs when "blood flow to the brain is blocked. This prevents the brain from getting oxygen and nutrients from the blood. Without oxygen and nutrients, brain cells begin to die within minutes. Sudden bleeding in the brain can also cause a stroke if it damages brain cells" (National Heart, Lung, and Blood Institute, n.d.).
The main causes of stroke are high blood pressure, diabetes, and smoking.
The main purpose of the project is to predict whether a patient is at risk of a stroke, using the dataset provided by Federico Soriano (2021) on Kaggle.
Strategy to solve the problem
To solve this problem, I implemented five machine learning models (k-nearest neighbors classifier, decision tree, random forest, Gaussian naive Bayes, and C-support vector classification) and compared their performance. The purpose of the comparison is to identify the model which shows the best performance in terms of predicting the probability of a stroke. The anticipated solution is the model which demonstrates acceptable performance based on the metrics listed in the section below, in particular the F1-score.
Metrics
To evaluate the machine learning models, I used the F1-score metric. The F1-score is the harmonic mean of recall and precision. If the F1-scores of several models are the same, we will choose the model with the highest recall, because a stroke-prediction model should minimize false negatives. This is because, in the case of determining the probability of a stroke, which depends largely on the patient's health behavior, it is important to identify all persons at risk of a stroke. For this specific problem, we can tolerate false positives (persons with a low risk of stroke flagged as at risk), because this only leads to further examination of their health (Sunastra, 2017).
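As an illustration, the F1-score can be computed directly with scikit-learn on a toy set of labels (not from the stroke dataset), where 1 means "stroke":

```python
from sklearn.metrics import f1_score, precision_score, recall_score

# Toy labels (hypothetical, for illustration only)
y_true = [1, 1, 1, 0, 0, 0, 0, 0]
y_pred = [1, 1, 0, 1, 0, 0, 0, 0]

precision = precision_score(y_true, y_pred)  # TP / (TP + FP) = 2 / 3
recall = recall_score(y_true, y_pred)        # TP / (TP + FN) = 2 / 3
f1 = f1_score(y_true, y_pred)                # harmonic mean of the two

print(precision, recall, f1)
```

Here one actual stroke case is missed (a false negative), which is exactly the kind of error the recall term penalizes.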
Overview of the dataset, provided by the author:
1) id: unique identifier
2) gender: "Male", "Female" or "Other"
3) age: age of the patient
4) hypertension: 0 if the patient doesn't have hypertension, 1 if the patient has hypertension
5) heart_disease: 0 if the patient doesn't have any heart diseases, 1 if the patient has a heart disease
6) ever_married: "No" or "Yes"
7) work_type: "children", "Govt_job", "Never_worked", "Private" or "Self-employed"
8) Residence_type: "Rural" or "Urban"
9) avg_glucose_level: average glucose level in blood
10) bmi: body mass index
11) smoking_status: "formerly smoked", "never smoked", "smokes" or "Unknown" (which means that the information is unavailable for this patient)
12) stroke: 1 if the patient had a stroke or 0 if not
# Import all libraries:
import matplotlib.pyplot as plt
import plotly
from plotly.subplots import make_subplots
import plotly.express as px
import plotly.graph_objects as go
import pandas as pd
import numpy as np
from sklearn.tree import DecisionTreeClassifier
from sklearn.svm import SVC
from sklearn.ensemble import RandomForestClassifier
from sklearn.neighbors import KNeighborsClassifier
from sklearn.naive_bayes import GaussianNB
from sklearn.metrics import roc_auc_score, accuracy_score, precision_score, recall_score, f1_score
from sklearn.model_selection import train_test_split
import warnings
warnings.filterwarnings('ignore')
# Read the csv (use the full path to the file, if needed):
data = pd.read_csv('healthcare-dataset-stroke-data.csv')
# Overview of dataset:
data
|   | id | gender | age | hypertension | heart_disease | ever_married | work_type | Residence_type | avg_glucose_level | bmi | smoking_status | stroke |
|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 9046 | Male | 67.0 | 0 | 1 | Yes | Private | Urban | 228.69 | 36.6 | formerly smoked | 1 |
| 1 | 51676 | Female | 61.0 | 0 | 0 | Yes | Self-employed | Rural | 202.21 | NaN | never smoked | 1 |
| 2 | 31112 | Male | 80.0 | 0 | 1 | Yes | Private | Rural | 105.92 | 32.5 | never smoked | 1 |
| 3 | 60182 | Female | 49.0 | 0 | 0 | Yes | Private | Urban | 171.23 | 34.4 | smokes | 1 |
| 4 | 1665 | Female | 79.0 | 1 | 0 | Yes | Self-employed | Rural | 174.12 | 24.0 | never smoked | 1 |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 5105 | 18234 | Female | 80.0 | 1 | 0 | Yes | Private | Urban | 83.75 | NaN | never smoked | 0 |
| 5106 | 44873 | Female | 81.0 | 0 | 0 | Yes | Self-employed | Urban | 125.20 | 40.0 | never smoked | 0 |
| 5107 | 19723 | Female | 35.0 | 0 | 0 | Yes | Self-employed | Rural | 82.99 | 30.6 | never smoked | 0 |
| 5108 | 37544 | Male | 51.0 | 0 | 0 | Yes | Private | Rural | 166.29 | 25.6 | formerly smoked | 0 |
| 5109 | 44679 | Female | 44.0 | 0 | 0 | Yes | Govt_job | Urban | 85.28 | 26.2 | Unknown | 0 |
5110 rows × 12 columns
data.info()
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 5110 entries, 0 to 5109
Data columns (total 12 columns):
 #   Column             Non-Null Count  Dtype
---  ------             --------------  -----
 0   id                 5110 non-null   int64
 1   gender             5110 non-null   object
 2   age                5110 non-null   float64
 3   hypertension       5110 non-null   int64
 4   heart_disease      5110 non-null   int64
 5   ever_married       5110 non-null   object
 6   work_type          5110 non-null   object
 7   Residence_type     5110 non-null   object
 8   avg_glucose_level  5110 non-null   float64
 9   bmi                4909 non-null   float64
 10  smoking_status     5110 non-null   object
 11  stroke             5110 non-null   int64
dtypes: float64(3), int64(4), object(5)
memory usage: 479.2+ KB
# Check for missing data:
data.isnull().sum()
id                     0
gender                 0
age                    0
hypertension           0
heart_disease          0
ever_married           0
work_type              0
Residence_type         0
avg_glucose_level      0
bmi                  201
smoking_status         0
stroke                 0
dtype: int64
Only one column has NaN values.
Let us check the proportion of missing data:
data['bmi'].isnull().sum()/data.shape[0]
0.03933463796477495
About 4% of the values in the "bmi" column are missing.
Let us start with the general overview and findings.
# Add a column for graphs:
data['Stroke?'] = data['stroke']==1
data['Stroke?']= data['Stroke?'].replace({True: 'Yes', False: 'No'})
#Rename job types:
data['work_type']= data['work_type'].str.replace("children", "Parental_leave")
data['work_type']= data['work_type'].str.replace("Govt_job", "Goverment_job")
data['work_type']= data['work_type'].str.replace("Self-employed", "Self_employed")
# Check variable correlations:
data.corr()
|   | id | age | hypertension | heart_disease | avg_glucose_level | bmi | stroke |
|---|---|---|---|---|---|---|---|
| id | 1.000000 | 0.003538 | 0.003550 | -0.001296 | 0.001092 | 0.003084 | 0.006388 |
| age | 0.003538 | 1.000000 | 0.276398 | 0.263796 | 0.238171 | 0.333398 | 0.245257 |
| hypertension | 0.003550 | 0.276398 | 1.000000 | 0.108306 | 0.174474 | 0.167811 | 0.127904 |
| heart_disease | -0.001296 | 0.263796 | 0.108306 | 1.000000 | 0.161857 | 0.041357 | 0.134914 |
| avg_glucose_level | 0.001092 | 0.238171 | 0.174474 | 0.161857 | 1.000000 | 0.175502 | 0.131945 |
| bmi | 0.003084 | 0.333398 | 0.167811 | 0.041357 | 0.175502 | 1.000000 | 0.042374 |
| stroke | 0.006388 | 0.245257 | 0.127904 | 0.134914 | 0.131945 | 0.042374 | 1.000000 |
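The correlations with the target can also be ranked directly; with the real data this would be `data.corr()['stroke'].sort_values(ascending=False)`. A minimal sketch on a synthetic frame (hypothetical numbers, not the dataset):

```python
import pandas as pd

# Synthetic stand-in for the stroke data (illustration only)
demo = pd.DataFrame({'age': [30, 50, 70, 80],
                     'bmi': [22, 31, 28, 27],
                     'stroke': [0, 0, 1, 1]})

# Rank features by their correlation with the target
ranked = demo.corr()['stroke'].sort_values(ascending=False)
print(ranked)
```

In the table above, age has the strongest positive correlation with stroke (0.245) among the features.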
fig = px.histogram(data, x='age',
nbins=60,
title='Age distribution', width=800, height=600,
color_discrete_sequence=px.colors.qualitative.Set1,
color='Stroke?', labels={'age':'Age'})
fig.update_layout(bargap=0.1)
Here we see some outliers, but we can say that strokes primarily affect older people.
fig = px.histogram(data, x='avg_glucose_level',
nbins=150,
title='Glucose level distribution', width=800, height=600,
color_discrete_sequence=px.colors.qualitative.Set3,
#marginal='box',
color='Stroke?', labels={'avg_glucose_level':'Average glucose level'})
fig.add_vline(x=100, line_width=1, line_dash="dash", line_color="green", annotation_text="Normal glucose level")
fig.update_layout(bargap=0.1)
For reference, a fasting blood sugar level below 100 mg/dL is generally considered normal (Mayo Clinic, n.d.); this threshold is marked on the plot.
fig = px.histogram(data, x='bmi', width=800, height=600,
nbins=150,
title='BMI distribution',
color_discrete_sequence=px.colors.qualitative.Safe,
color='Stroke?', labels={'bmi':'Body mass index', 'count':'Count'} )
fig.add_vrect(x0="0", x1="18.5", annotation_text="Underweight", annotation_position="top left",
fillcolor="green", opacity=0.05, line_width=0)
fig.add_vrect(x0="25", x1="29.9", annotation_text="Pre-obesity", annotation_position="top left",
fillcolor="green", opacity=0.09, line_width=0.01)
fig.update_layout(bargap=0.1)
For reference, the World Health Organisation (n.d.) classifies a BMI below 18.5 as underweight and a BMI between 25 and 29.9 as pre-obesity; these ranges are shaded on the plot.
Other health factors:
def show_graph(parameter, row, column):
    '''
    Adds two bars (no-stroke and stroke counts) for the given feature
    to the subplot figure.
    Parameters:
        parameter: name of the feature;
        row: subplot row;
        column: subplot column.
    '''
    stroke_parameter = data.groupby(parameter)['stroke'].mean()
    other = 1 - stroke_parameter.values
    # sort_index() aligns the counts with the groupby order:
    counts = data[parameter].value_counts().sort_index()
    fig.add_trace(go.Bar(x=counts.index, y=counts.values * other,
                         text=other, texttemplate='No stroke : %{text:.2f}',
                         textposition='auto'), row, column)
    fig.add_trace(go.Bar(x=counts.index, y=counts.values * stroke_parameter.values,
                         text=stroke_parameter, texttemplate='Stroke : %{text:.2f}',
                         textposition='auto'), row, column)
fig = make_subplots(rows=1, cols=3, shared_yaxes=True,subplot_titles=("Hypertension",
"Heart disease",
"Smoking"),vertical_spacing=0.1)
show_graph('hypertension',1,1)
show_graph('heart_disease',1,2)
show_graph('smoking_status',1,3)
fig.update_layout(barmode='group')
fig.update_layout(height = 400, showlegend=False)
fig.update_layout(paper_bgcolor=px.colors.qualitative.Pastel2[6],bargap=0.2)
fig.show()
We can say that the presence of hypertension (1) or heart disease (1) is also a risk factor. Smokers, even former ones, have a higher chance of having a stroke.
Other factors:
# Reuse the show_graph function defined above:
fig = make_subplots(rows=1, cols=4, shared_yaxes=True,subplot_titles=("Gender","Married",
"Work type","Residence type"),vertical_spacing=0.1)
show_graph('gender',1,1)
show_graph('ever_married',1,2)
show_graph('work_type',1,3)
show_graph('Residence_type',1,4)
fig.update_layout(barmode='group')
fig.update_layout(height = 400, showlegend=False)
fig.update_layout(paper_bgcolor=px.colors.qualitative.Pastel2[6],bargap=0.2)
fig.show()
Based on these plots, we can say that both males and females suffer from strokes.
People who have never married appear to be at risk. We can speculate that self-employed people have less stress and are not in the risk group. Residence type appears to have no effect on the probability of having a stroke.
# Impute the missing BMI values with the column mean:
data['bmi'] = data['bmi'].fillna(data['bmi'].mean())
# Check for missing data:
data.isnull().sum()
id                   0
gender               0
age                  0
hypertension         0
heart_disease        0
ever_married         0
work_type            0
Residence_type       0
avg_glucose_level    0
bmi                  0
smoking_status       0
stroke               0
Stroke?              0
dtype: int64
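Mean imputation is simple, but an alternative worth noting is median imputation, since the median is less sensitive to outliers. A minimal sketch on a small synthetic column (hypothetical values, not the dataset):

```python
import pandas as pd
from sklearn.impute import SimpleImputer

# Synthetic BMI column with missing values (illustration only)
df = pd.DataFrame({'bmi': [22.0, None, 30.5, 27.1, None]})

# Fill the gaps with the column median
df['bmi'] = SimpleImputer(strategy='median').fit_transform(df[['bmi']]).ravel()
print(df['bmi'].tolist())
```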
# Check for values of the smoking status variable:
data['smoking_status'].value_counts()
never smoked       1892
Unknown            1544
formerly smoked     885
smokes              789
Name: smoking_status, dtype: int64
# Check for values of the gender variable:
data['gender'].value_counts()
Female    2994
Male      2115
Other        1
Name: gender, dtype: int64
# Drop the "id" and "Stroke?" columns and the single instance where gender is "Other":
data = data.drop(columns=['id', 'Stroke?'])
data = data[data.gender != 'Other']
# Encode variables as binaries
data['ever_married'] = data['ever_married'].replace({'No': 0, 'Yes': 1})
data['Residence_type'] = data['Residence_type'].replace({'Rural': 0, 'Urban': 1})
data['gender'] = data['gender'].replace({'Male': 1, 'Female': 0})
data.head()
|   | gender | age | hypertension | heart_disease | ever_married | work_type | Residence_type | avg_glucose_level | bmi | smoking_status | stroke |
|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 1 | 67.0 | 0 | 1 | 1 | Private | 1 | 228.69 | 36.600000 | formerly smoked | 1 |
| 1 | 0 | 61.0 | 0 | 0 | 1 | Self_employed | 0 | 202.21 | 28.893237 | never smoked | 1 |
| 2 | 1 | 80.0 | 0 | 1 | 1 | Private | 0 | 105.92 | 32.500000 | never smoked | 1 |
| 3 | 0 | 49.0 | 0 | 0 | 1 | Private | 1 | 171.23 | 34.400000 | smokes | 1 |
| 4 | 0 | 79.0 | 1 | 0 | 1 | Self_employed | 0 | 174.12 | 24.000000 | never smoked | 1 |
work = pd.get_dummies(data['work_type'])
smoke = pd.get_dummies(data['smoking_status'], drop_first=True)
data = data.join([work, smoke])
data.head()
|   | gender | age | hypertension | heart_disease | ever_married | work_type | Residence_type | avg_glucose_level | bmi | smoking_status | stroke | Goverment_job | Never_worked | Parental_leave | Private | Self_employed | formerly smoked | never smoked | smokes |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 1 | 67.0 | 0 | 1 | 1 | Private | 1 | 228.69 | 36.600000 | formerly smoked | 1 | 0 | 0 | 0 | 1 | 0 | 1 | 0 | 0 |
| 1 | 0 | 61.0 | 0 | 0 | 1 | Self_employed | 0 | 202.21 | 28.893237 | never smoked | 1 | 0 | 0 | 0 | 0 | 1 | 0 | 1 | 0 |
| 2 | 1 | 80.0 | 0 | 1 | 1 | Private | 0 | 105.92 | 32.500000 | never smoked | 1 | 0 | 0 | 0 | 1 | 0 | 0 | 1 | 0 |
| 3 | 0 | 49.0 | 0 | 0 | 1 | Private | 1 | 171.23 | 34.400000 | smokes | 1 | 0 | 0 | 0 | 1 | 0 | 0 | 0 | 1 |
| 4 | 0 | 79.0 | 1 | 0 | 1 | Self_employed | 0 | 174.12 | 24.000000 | never smoked | 1 | 0 | 0 | 0 | 0 | 1 | 0 | 1 | 0 |
#Rename columns:
data = data.rename(columns={'formerly smoked': 'formerly_smoked', 'never smoked': 'never_smoked'})
data = data.drop(columns=['smoking_status', 'work_type'])
data.head()
|   | gender | age | hypertension | heart_disease | ever_married | Residence_type | avg_glucose_level | bmi | stroke | Goverment_job | Never_worked | Parental_leave | Private | Self_employed | formerly_smoked | never_smoked | smokes |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 1 | 67.0 | 0 | 1 | 1 | 1 | 228.69 | 36.600000 | 1 | 0 | 0 | 0 | 1 | 0 | 1 | 0 | 0 |
| 1 | 0 | 61.0 | 0 | 0 | 1 | 0 | 202.21 | 28.893237 | 1 | 0 | 0 | 0 | 0 | 1 | 0 | 1 | 0 |
| 2 | 1 | 80.0 | 0 | 1 | 1 | 0 | 105.92 | 32.500000 | 1 | 0 | 0 | 0 | 1 | 0 | 0 | 1 | 0 |
| 3 | 0 | 49.0 | 0 | 0 | 1 | 1 | 171.23 | 34.400000 | 1 | 0 | 0 | 0 | 1 | 0 | 0 | 0 | 1 |
| 4 | 0 | 79.0 | 1 | 0 | 1 | 0 | 174.12 | 24.000000 | 1 | 0 | 0 | 0 | 0 | 1 | 0 | 1 | 0 |
data['stroke'].value_counts().plot(kind='bar')
print(data['stroke'].value_counts())
plt.title('Distribution by "stroke"', fontsize=12)
0    4860
1     249
Name: stroke, dtype: int64
The model trained on the default version of the dataset gives an accuracy of 95% with a very low F1-score (0.05) for predicting a high risk of stroke. Therefore, I decided to implement either oversampling or undersampling of the dataset.
In this case, however, oversampling based on 249 rows of "stroke cases" can cause overfitting, especially at higher oversampling rates, and decrease classifier performance (Branco, Torgo and Ribeiro, 2015).
So, I decided to undersample the dataset. Making the set perfectly balanced, however, can lead to worse prediction of no-stroke cases, so I undersampled the dataset to 1,109 entries and still used techniques which are fairly robust when dealing with imbalanced data:
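For reference, the same random undersampling idea can be sketched with pandas `sample` and a fixed `random_state`, which makes the result reproducible. The counts below are synthetic stand-ins, not the real 4860/249 split:

```python
import pandas as pd

# Synthetic imbalanced target (hypothetical counts, for illustration)
df = pd.DataFrame({'stroke': [0] * 50 + [1] * 10, 'x': range(60)})

majority = df[df['stroke'] == 0].sample(n=20, random_state=42)  # keep 20 of the 50
minority = df[df['stroke'] == 1]                                # keep all stroke rows
balanced = pd.concat([majority, minority])
print(balanced['stroke'].value_counts())
```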
# Remove 4000 of the 4860 no-stroke rows (leaving 860 no-stroke vs. 249 stroke cases):
remove_number = 4000
strokes_false = data.loc[data['stroke'] == 0]
strokes_false
drop = np.random.choice(strokes_false.index, remove_number, replace=False)  # unseeded: results vary between runs
df_downsampled = data.drop(drop)
df_downsampled
|   | gender | age | hypertension | heart_disease | ever_married | Residence_type | avg_glucose_level | bmi | stroke | Goverment_job | Never_worked | Parental_leave | Private | Self_employed | formerly_smoked | never_smoked | smokes |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 1 | 67.0 | 0 | 1 | 1 | 1 | 228.69 | 36.600000 | 1 | 0 | 0 | 0 | 1 | 0 | 1 | 0 | 0 |
| 1 | 0 | 61.0 | 0 | 0 | 1 | 0 | 202.21 | 28.893237 | 1 | 0 | 0 | 0 | 0 | 1 | 0 | 1 | 0 |
| 2 | 1 | 80.0 | 0 | 1 | 1 | 0 | 105.92 | 32.500000 | 1 | 0 | 0 | 0 | 1 | 0 | 0 | 1 | 0 |
| 3 | 0 | 49.0 | 0 | 0 | 1 | 1 | 171.23 | 34.400000 | 1 | 0 | 0 | 0 | 1 | 0 | 0 | 0 | 1 |
| 4 | 0 | 79.0 | 1 | 0 | 1 | 0 | 174.12 | 24.000000 | 1 | 0 | 0 | 0 | 0 | 1 | 0 | 1 | 0 |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 5092 | 1 | 76.0 | 0 | 0 | 1 | 1 | 82.35 | 38.900000 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 1 | 0 |
| 5094 | 1 | 13.0 | 0 | 0 | 0 | 1 | 82.38 | 24.300000 | 0 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 0 |
| 5101 | 0 | 45.0 | 0 | 0 | 1 | 1 | 97.95 | 24.500000 | 0 | 0 | 0 | 0 | 1 | 0 | 0 | 0 | 0 |
| 5105 | 0 | 80.0 | 1 | 0 | 1 | 1 | 83.75 | 28.893237 | 0 | 0 | 0 | 0 | 1 | 0 | 0 | 1 | 0 |
| 5108 | 1 | 51.0 | 0 | 0 | 1 | 0 | 166.29 | 25.600000 | 0 | 0 | 0 | 0 | 1 | 0 | 1 | 0 | 0 |
1109 rows × 17 columns
df_downsampled['stroke'].value_counts().plot(kind='bar')
print(df_downsampled['stroke'].value_counts())
plt.title('Distribution by stroke after downsampling:', fontsize=12)
0    860
1    249
Name: stroke, dtype: int64
# Split the downsampled dataset into the test and the train subsets:
X = df_downsampled.drop(columns=['stroke'])
y = df_downsampled['stroke']
x_train, x_test, y_train, y_test = train_test_split(X, y, test_size=.2, random_state= 100)
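Since the downsampled set is still imbalanced (roughly 860 vs. 249), passing `stratify=y` to `train_test_split` would keep the class ratio identical in both subsets. A minimal sketch on synthetic 80/20 labels (not the project data):

```python
import numpy as np
from sklearn.model_selection import train_test_split

# Synthetic features and an 80/20 imbalanced target (illustration only)
X_demo = np.arange(100).reshape(-1, 1)
y_demo = np.array([0] * 80 + [1] * 20)

# stratify=y_demo preserves the 0.2 positive rate in both subsets
x_tr, x_te, y_tr, y_te = train_test_split(
    X_demo, y_demo, test_size=0.2, random_state=100, stratify=y_demo)
print(y_tr.mean(), y_te.mean())
```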
Implementation
For implementing the project, I trained five machine learning models:
the k-nearest neighbors classifier, decision tree, random forest, Gaussian naive Bayes, and C-support vector classification.
These models were chosen because they are commonly used for classification tasks. The k-nearest neighbors classifier assigns a class to a data point based on the majority class among its nearest neighbors in feature space. The decision tree model learns decision rules from the data features. The random forest model follows the same principle, but aggregates multiple decision trees. The Gaussian naive Bayes model applies Bayes' theorem under a feature-independence assumption, whereas C-support vector classification is a form of the support vector machine model.
These models follow different principles, so it is important to identify which one is the most appropriate for the task. To do so, I trained all five models and compared their performance using the same set of metrics. As noted above, I consider the F1-score the key indicator of model performance.
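A more stable way to compare models on the same metric is cross-validated F1, which averages over several splits rather than relying on a single train/test partition. A sketch on synthetic data; with the real project this would use X and y from the downsampled set:

```python
import numpy as np
from sklearn.model_selection import cross_val_score
from sklearn.tree import DecisionTreeClassifier

# Synthetic data where the label depends only on the first feature
rng = np.random.RandomState(0)
X_demo = rng.rand(200, 5)
y_demo = (X_demo[:, 0] > 0.5).astype(int)

# 5-fold cross-validated F1 for one of the candidate models
scores = cross_val_score(DecisionTreeClassifier(random_state=0),
                         X_demo, y_demo, cv=5, scoring='f1')
print(scores.mean())
```

The same call, repeated for each model in the comparison, would give per-model mean F1-scores that are less sensitive to a lucky or unlucky split.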
Refinement
The examination of the model performance shows that the models do not perform optimally, in particular for predicting class 1 cases (high probability of a stroke). The lowest F1-score is observed for the C-support vector classification model, and the highest F1-score for class 1 cases is observed for the Gaussian naive Bayes model. However, the Gaussian naive Bayes model has lower accuracy than the k-nearest neighbors classifier and the random forest classifier; the latter two models show the best balance of F1-score and accuracy.
models = {
"K-Nearest Neighbors": KNeighborsClassifier(n_neighbors = 4),
"Decision Tree": DecisionTreeClassifier(),
"GaussianNB": GaussianNB(),
"SVC": SVC(),
"Random Forest": RandomForestClassifier(n_estimators = 350)
}
for name, model in models.items():
    model.fit(x_train, y_train)
from sklearn.metrics import confusion_matrix, classification_report
def evaluate_model(model):
    '''
    Prints a classification report for the model on the test data.
    '''
    prediction = model.predict(x_test)
    report = classification_report(y_test, prediction)
    print(report)
print("Model performance on test data: \n-----------------")
for name, model in models.items():
    print('\n' + 'Model: ' + name + ":")
    evaluate_model(model)
    print("-----------------------------------------------------")
Model performance on test data:
-----------------
Model: K-Nearest Neighbors:
precision recall f1-score support
0 0.78 0.93 0.85 164
1 0.56 0.24 0.34 58
accuracy 0.75 222
macro avg 0.67 0.59 0.59 222
weighted avg 0.72 0.75 0.71 222
-----------------------------------------------------
Model: Decision Tree:
precision recall f1-score support
0 0.79 0.79 0.79 164
1 0.41 0.41 0.41 58
accuracy 0.69 222
macro avg 0.60 0.60 0.60 222
weighted avg 0.69 0.69 0.69 222
-----------------------------------------------------
Model: GaussianNB:
precision recall f1-score support
0 0.98 0.37 0.53 164
1 0.35 0.98 0.52 58
accuracy 0.53 222
macro avg 0.67 0.67 0.53 222
weighted avg 0.82 0.53 0.53 222
-----------------------------------------------------
Model: SVC:
precision recall f1-score support
0 0.74 1.00 0.85 164
1 0.00 0.00 0.00 58
accuracy 0.74 222
macro avg 0.37 0.50 0.42 222
weighted avg 0.55 0.74 0.63 222
-----------------------------------------------------
Model: Random Forest:
precision recall f1-score support
0 0.78 0.91 0.84 164
1 0.53 0.28 0.36 58
accuracy 0.75 222
macro avg 0.66 0.60 0.60 222
weighted avg 0.72 0.75 0.72 222
-----------------------------------------------------
Model Evaluation and Validation
Based on the F1-score metric as well as accuracy, I decided that the k-nearest neighbors classifier and the random forest classifier offer the best performance for the task. Between these two models, I chose the random forest classifier because of its higher F1-score for predicting class 1 observations. I acknowledge that the model performance for class 1 (high probability of a stroke) is still suboptimal, but I think it is close to the maximum that can be reached without expanding and improving the dataset.
Justification
In my view, the random forest model performs the best for this task because it is robust and non-parametric, and therefore less affected by outliers. It is also one of the most commonly used machine learning models, so its strong performance is not surprising.
Reflection
In this project, I implemented five machine learning models and compared their performance to identify which one performs the best for predicting a high or low probability of a stroke. Based on my analysis, I found that the random forest model performs the best for this task. In my view, its higher performance is attributable to its robustness.
The aspect of the project which I found most difficult was the data preparation. All the models performed poorly on the initial version of the dataset, and I had to downsample it in order to improve their performance.
Improvement
In general, the dataset itself could be improved first. It is not only imbalanced, but also does not differentiate between types of strokes: ischemic strokes, which are mostly caused by the blockage of a blood vessel, and hemorrhagic strokes, which can occur after head trauma (Caceres and Goldstein, 2012). Second, it might be interesting to see how deep learning would perform on this task, although I think the dataset might be too small to achieve high performance with a deep learning model.
# Save the model for the app (same configuration as evaluated above):
prediction_model = RandomForestClassifier(n_estimators=350).fit(x_train, y_train)

# Save the model as a pickle file:
import pickle
pickle.dump(prediction_model, open("model.p", 'wb'))
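The app can later reload the pickled model and call `predict` on new patient data (the feature order must match the training columns). A self-contained sketch of the round trip, using a tiny stand-in model and the hypothetical file name `demo_model.p`:

```python
import pickle
from sklearn.ensemble import RandomForestClassifier

# Train a tiny stand-in model on toy data (illustration only)
demo_model = RandomForestClassifier(n_estimators=5, random_state=0).fit(
    [[0, 0], [1, 1], [0, 1], [1, 0]], [0, 1, 0, 1])

# Save it the same way the project saves model.p
with open("demo_model.p", "wb") as f:
    pickle.dump(demo_model, f)

# Reload it the way the app would, then predict on a new observation
with open("demo_model.p", "rb") as f:
    loaded = pickle.load(f)
print(loaded.predict([[1, 1]]))
```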
1. Branco, P., Torgo, L. & Ribeiro, R. (2015). A Survey of Predictive Modelling under Imbalanced Distributions. Machine Learning. Retrieved from https://machinelearningmastery.com/random-oversampling-and-undersampling-for-imbalanced-classification/
2. Caceres, J. A., & Goldstein, J. N. (2012). Intracranial hemorrhage. Emergency medicine clinics of North America, 30(3), 771–794. https://doi.org/10.1016/j.emc.2012.06.003
3. Mayo Clinic. (n.d.). Diabetes. Diseases & Conditions. Retrieved from https://www.mayoclinic.org/diseases-conditions/diabetes/diagnosis-treatment/drc-20371451
4. National Heart, Lung, and Blood Institute. (n.d.). Stroke. Retrieved from https://www.nhlbi.nih.gov/health-topics/stroke
5. Soriano, F. (2021). Stroke Prediction Dataset. Kaggle. Retrieved from https://www.kaggle.com/fedesoriano/stroke-prediction-dataset
6. Sunastra, M. (2017). Performance metrics for classification problems in machine learning. Medium. Retrieved from https://medium.com/@MohammedS/performance-metrics-for-classification-problems-in-machine-learning-part-i-b085d432082b
7. World Health Organisation. (n.d.). Body mass index - BMI. A healthy lifestyle. Retrieved from https://www.euro.who.int/en/health-topics/disease-prevention/nutrition/a-healthy-lifestyle/body-mass-index-bmi